#!/usr/bin/env python3
"""
Phase 4: Leave-One-Out Country Screen
Identifies influential countries that shift Z₁ coefficient > 1 SE when dropped.
Computes approximate Cook's distance.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Root of the project tree; every input and output path below is built from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")

# `model` is imported from multilateral/src, which is pushed to the front of
# sys.path first so the import below can resolve.
sys.path.insert(0, str(PROJECT_DIR.joinpath("multilateral", "src")))
from model import PanelGLS

FOLLOWUP_DIR = PROJECT_DIR.joinpath("multilateral", "followup")
TABLE_DIR = PROJECT_DIR.joinpath("fragility", "output", "tables")
# Create the output directory (and any missing parents) up front.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

print("=" * 70, "PHASE 4: LEAVE-ONE-OUT COUNTRY SCREEN", "=" * 70, sep="\n")

# ── Load panel ─────────────────────────────────────────────────────────────
# Read the processed panel and keep only the years 1986 through 2024 (both
# endpoints included).
panel = pd.read_csv(FOLLOWUP_DIR.joinpath("data", "processed", "full_panel.csv"))
panel = panel.loc[panel['year'].between(1986, 2024)].copy()

# ── OECD membership ───────────────────────────────────────────────────────
# ISO3 codes treated as OECD members; used for the `is_oecd` flag and for
# restricting the bond-yield test to OECD countries.
OECD_MEMBERS = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI',
    'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR',
    'GRC', 'HUN', 'IRL', 'ISL', 'ISR', 'ITA', 'JPN', 'KOR',
    'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR', 'NZL', 'POL',
    'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}


def loo_screen(dep_var, indep_vars, data, label, z1_idx=0):
    """
    Leave-one-out screen: drop each country and record Z₁ coefficient.
    Flag countries where |Z₁_full - Z₁_loo| > 1 SE.

    The model is fitted once on all complete rows and then once per country
    with that country's rows removed. Progress and a short report are
    printed along the way.

    Parameters
    ----------
    dep_var : str
        Column name of the dependent variable.
    indep_vars : list of str
        Column names of the regressors, in the order handed to the model.
    data : pandas.DataFrame
        Panel containing ``dep_var``, ``indep_vars``, ``iso3`` and ``year``.
    label : str
        Name of the specification; printed and stored in the ``label`` column.
    z1_idx : int, default 0
        Position of the Z₁ regressor inside ``indep_vars``.

    Returns
    -------
    pandas.DataFrame or None
        One row per dropped country with columns ``iso3``, ``n_obs_dropped``,
        ``is_oecd``, ``Z1_loo``, ``Z1_delta``, ``Z1_delta_se``, ``p_loo``,
        ``R2_loo``, ``cook_d``, ``influential``, ``label``, ``Z1_full`` and
        ``Z1_se_full``. ``None`` is returned when fewer than 100 complete
        rows are available, or when dropping one country would leave fewer
        than 10 countries (no leave-one-out fit would be run at all).
    """
    comp = data.dropna(subset=[dep_var] + indep_vars).copy()
    if len(comp) < 100:
        print(f"  SKIP {label}: only {len(comp)} obs")
        return None

    # Full sample
    gls = PanelGLS()
    gls.fit(comp[dep_var].values, comp[indep_vars].values,
            comp['iso3'].values, comp['year'].values)

    z1_full = gls.beta[z1_idx]
    z1_se = gls.se[z1_idx]
    full_n = gls.n_obs
    full_r2 = gls.r_squared

    # Rows per country, computed once: gives both the country count and the
    # number of observations removed by each leave-one-out step.
    obs_per_country = comp['iso3'].value_counts()
    n_countries = len(obs_per_country)

    print(f"\n  {label}: Z₁ = {z1_full:.3f} (SE={z1_se:.3f}, p={gls.pvalues[z1_idx]:.4f})")
    print(f"  Full sample: {full_n} obs, {n_countries} countries, R²={full_r2:.4f}")

    # Every leave-one-out subsample has exactly one country fewer than the
    # full sample, so the minimum-country requirement is the same for all of
    # them. Checking it once avoids building an empty result frame, which
    # would fail below when the 'influential' column is looked up.
    if n_countries - 1 < 10:
        print(f"  SKIP {label} LOO: only {n_countries} countries")
        return None

    # Approximate Cook's distance for Z₁
    # D_i ≈ (β_full - β_loo)² / (k * MSE)
    # k and MSE depend only on the full-sample fit, so they are computed
    # once rather than on every iteration.
    k = len(indep_vars)
    mse = (1 - full_r2) * comp[dep_var].var()

    # LOO
    rows = []
    for iso in sorted(obs_per_country.index):
        sub = comp[comp['iso3'] != iso]

        gls_loo = PanelGLS()
        gls_loo.fit(sub[dep_var].values, sub[indep_vars].values,
                    sub['iso3'].values, sub['year'].values)

        z1_loo = gls_loo.beta[z1_idx]
        delta = z1_full - z1_loo
        cook_d = (delta ** 2) / (k * mse) if mse > 0 else np.nan

        rows.append({
            'iso3': iso,
            'n_obs_dropped': int(obs_per_country.loc[iso]),
            'is_oecd': iso in OECD_MEMBERS,
            'Z1_loo': z1_loo,
            'Z1_delta': delta,
            # Shift expressed in units of the full-sample standard error.
            'Z1_delta_se': abs(delta) / z1_se if z1_se > 0 else np.nan,
            'p_loo': gls_loo.pvalues[z1_idx],
            'R2_loo': gls_loo.r_squared,
            'cook_d': cook_d,
            'influential': abs(delta) > z1_se,
        })

    loo_df = pd.DataFrame(rows)
    loo_df['label'] = label
    loo_df['Z1_full'] = z1_full
    loo_df['Z1_se_full'] = z1_se

    # Report influential countries
    influential = loo_df[loo_df['influential']].sort_values('Z1_delta_se', ascending=False)
    print(f"\n  Influential countries ({len(influential)}/{len(loo_df)}):")
    for _, r in influential.head(15).iterrows():
        direction = "↑" if r['Z1_delta'] > 0 else "↓"
        oecd = "OECD" if r['is_oecd'] else ""
        print(f"    {r['iso3']:>4} {oecd:>4}: ΔZ₁={r['Z1_delta']:+.3f} ({r['Z1_delta_se']:.1f}σ) "
              f"{direction}  Cook's D={r['cook_d']:.4f}  n={r['n_obs_dropped']}")

    # Top 10 by Cook's distance
    top_cook = loo_df.nlargest(10, 'cook_d')
    print(f"\n  Top 10 by Cook's distance:")
    for _, r in top_cook.iterrows():
        oecd = "OECD" if r['is_oecd'] else ""
        print(f"    {r['iso3']:>4} {oecd:>4}: Cook's D={r['cook_d']:.4f}  ΔZ₁={r['Z1_delta']:+.3f}")

    return loo_df


# ════════════════════════════════════════════════════════════════════════════
# TEST 1: CA/GDP baseline (demographics → current account)
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 1: Z₁ → CA/GDP (demographic baseline)")
print("=" * 70)

ca_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_rel_opw']
# Regressors are filtered down to the columns actually present in the panel.
avail = [v for v in ca_vars if v in panel.columns]
# loo_screen reads the Z₁ coefficient by position. If 'Z_1' were filtered out
# above, position 0 would silently refer to a different regressor, so the
# screen only runs when 'Z_1' is present and its index is passed explicitly.
# The dependent variable is checked too, since dropna() raises KeyError on a
# missing column.
if 'ca_gdp' in panel.columns and 'Z_1' in avail:
    loo1 = loo_screen('ca_gdp', avail, panel, 'Z→CA baseline',
                      z1_idx=avail.index('Z_1'))
else:
    print("  ca_gdp or Z_1 not available")
    loo1 = None


# ════════════════════════════════════════════════════════════════════════════
# TEST 2: Bond yields (demographics → rates, OECD only)
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 2: Z₁ → govt_bond_10y (OECD)")
print("=" * 70)

if 'govt_bond_10y' in panel.columns:
    # Restrict the sample to the countries listed in OECD_MEMBERS.
    oecd_panel = panel[panel['iso3'].isin(OECD_MEMBERS)].copy()
    rate_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'rgdp_growth']
    avail_rate = [v for v in rate_vars if v in oecd_panel.columns]
    # loo_screen reads the Z₁ coefficient by position; pass the real index of
    # 'Z_1' and skip the test if the column was filtered out, otherwise the
    # screen would silently report on a different regressor.
    if 'Z_1' in avail_rate:
        loo2 = loo_screen('govt_bond_10y', avail_rate, oecd_panel, 'Z→10y yield (OECD)',
                          z1_idx=avail_rate.index('Z_1'))
    else:
        print("  Z_1 not available")
        loo2 = None
else:
    print("  govt_bond_10y not available")
    loo2 = None


# ════════════════════════════════════════════════════════════════════════════
# TEST 3: Investment/GDP (demographics → investment)
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 3: Z₁ → gross_investment_gdp")
print("=" * 70)

# Use the first investment column that exists in the panel, in order of
# preference; None when neither is present.
iy_col = next(
    (col for col in ['gross_investment_gdp', 'gross_fixed_investment_gdp']
     if col in panel.columns),
    None,
)

if iy_col:
    iy_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']
    avail_iy = [v for v in iy_vars if v in panel.columns]
    # loo_screen reads the Z₁ coefficient by position; pass the real index of
    # 'Z_1' and skip the test if the column was filtered out, otherwise the
    # screen would silently report on a different regressor.
    if 'Z_1' in avail_iy:
        loo3 = loo_screen(iy_col, avail_iy, panel, f'Z→{iy_col}',
                          z1_idx=avail_iy.index('Z_1'))
    else:
        print("  Z_1 not available")
        loo3 = None
else:
    # Report the skip explicitly instead of skipping silently.
    print("  investment/GDP column not available")
    loo3 = None


# ════════════════════════════════════════════════════════════════════════════
# Combine and save
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("COMBINED LOO RESULTS")
print("=" * 70)

# Screens that were skipped are None; keep only the ones that produced rows.
all_loo = [loo for loo in [loo1, loo2, loo3] if loo is not None]

if all_loo:
    combined = pd.concat(all_loo, ignore_index=True)
    combined.to_csv(TABLE_DIR / "table4_loo_country_screen.csv", index=False)

    # Summary: most influential countries across all tests
    influential_counts = combined[combined['influential']].groupby('iso3').size()
    influential_counts = influential_counts.sort_values(ascending=False)

    print(f"\nCountries influential in multiple tests:")
    for iso, count in influential_counts.head(20).items():
        oecd = "OECD" if iso in OECD_MEMBERS else ""
        tests = combined[(combined['iso3'] == iso) & (combined['influential'])]['label'].tolist()
        print(f"  {iso:>4} {oecd:>4}: influential in {count} test(s) — {', '.join(tests)}")

    # Markdown summary
    md = ["# Table 4: Leave-One-Out Country Screen\n"]
    md.append("Countries where dropping shifts Z₁ by more than 1 standard error.\n")

    for label in combined['label'].unique():
        sub = combined[combined['label'] == label]
        n_inf = sub['influential'].sum()
        n_total = len(sub)
        # Full-sample values are repeated on every row of a test; read the first.
        full_z1 = sub['Z1_full'].iloc[0]
        full_se = sub['Z1_se_full'].iloc[0]

        md.append(f"\n## {label}\n")
        md.append(f"Full sample Z₁ = {full_z1:.3f} (SE = {full_se:.3f})")
        md.append(f"Influential countries: {n_inf}/{n_total} ({n_inf/n_total*100:.0f}%)\n")

        # At most ten influential countries, largest Cook's distance first.
        inf = sub[sub['influential']].nlargest(10, 'cook_d')
        if inf.empty:
            # Avoid emitting a table header with no rows underneath it.
            md.append("No influential countries.\n")
            continue

        md.append("| Country | OECD | N dropped | ΔZ₁ | ΔZ₁/SE | Cook's D | p (LOO) |")
        md.append("|:---|:---|---:|---:|---:|---:|---:|")
        for _, r in inf.iterrows():
            oecd = "Yes" if r['is_oecd'] else "No"
            md.append(f"| {r['iso3']} | {oecd} | {r['n_obs_dropped']} | {r['Z1_delta']:+.3f} | "
                      f"{r['Z1_delta_se']:.1f}σ | {r['cook_d']:.4f} | {r['p_loo']:.4f} |")

    md.append(f"\n## Cross-Test Influence\n")
    md.append("Countries influential in 2+ tests:\n")
    multi = influential_counts[influential_counts >= 2]
    if len(multi) > 0:
        md.append("| Country | OECD | Tests influential |")
        md.append("|:---|:---|---:|")
        for iso, count in multi.items():
            oecd = "Yes" if iso in OECD_MEMBERS else "No"
            md.append(f"| {iso} | {oecd} | {count} |")
    else:
        md.append("None found across the tested specifications.\n")

    md_text = "\n".join(md)
    # The report contains non-ASCII characters (Z₁, Δ, σ, →), so the encoding
    # is pinned instead of being left to the locale default.
    (TABLE_DIR / "table4_loo_country_screen.md").write_text(md_text, encoding="utf-8")

    # Print first part; only mark it as truncated when it actually is.
    suffix = "..." if len(md_text) > 2000 else ""
    print(f"\n{md_text[:2000]}{suffix}")

    print(f"\nSaved results to {TABLE_DIR}")
else:
    # Nothing was written, so do not claim that results were saved.
    print("\nNo LOO results produced; nothing saved.")

print("Phase 4 complete.")
